{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#| eval: false\n",
    "! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from __future__ import annotations\n",
    "from fastai.basics import *\n",
    "from fastai.tabular.core import *\n",
    "from fastai.tabular.model import *\n",
    "from fastai.tabular.data import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp tabular.learner"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tabular learner\n",
    "\n",
    "> The function to immediately get a `Learner` ready to train for tabular data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main function you probably want to use in this module is `tabular_learner`. It will automatically create a `TabularModel` suitable for your data and infer the right loss function. See the [tabular tutorial](http://docs.fast.ai/tutorial.tabular.html) for an example of use in context."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Main functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "class TabularLearner(Learner):\n",
    "    \"`Learner` for tabular data\"\n",
    "    def predict(self, \n",
    "        row:pd.Series, # Features to be predicted\n",
    "    ):\n",
    "        \"Predict on a single sample\"\n",
    "        dl = self.dls.test_dl(row.to_frame().T)\n",
    "        dl.dataset.conts = dl.dataset.conts.astype(np.float32)\n",
    "        inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)\n",
    "        b = (*tuplify(inp),*tuplify(dec_preds))\n",
    "        full_dec = self.dls.decode(b)\n",
    "        return full_dec,dec_preds[0],preds[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/learner.py#L16){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularLearner\n",
       "\n",
       ">      TabularLearner (dls:fastai.data.core.DataLoaders, model:Callable,\n",
       ">                      loss_func:Optional[Callable]=None, opt_func:fastai.optimi\n",
       ">                      zer.Optimizer|fastai.optimizer.OptimWrapper=<function\n",
       ">                      Adam>, lr:float|slice=0.001, splitter:Callable=<function\n",
       ">                      trainable_params>, cbs:fastai.callback.core.Callback|coll\n",
       ">                      ections.abc.MutableSequence|None=None, metrics:Union[Call\n",
       ">                      able,collections.abc.MutableSequence,NoneType]=None,\n",
       ">                      path:str|pathlib.Path|None=None,\n",
       ">                      model_dir:str|pathlib.Path='models',\n",
       ">                      wd:float|int|None=None, wd_bn_bias:bool=False,\n",
       ">                      train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95),\n",
       ">                      default_cbs:bool=True)\n",
       "\n",
       "*`Learner` for tabular data*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| dls | DataLoaders |  | `DataLoaders` containing fastai or PyTorch `DataLoader`s |\n",
       "| model | Callable |  | PyTorch model for training or inference |\n",
       "| loss_func | Optional | None | Loss function. Defaults to `dls` loss |\n",
       "| opt_func | fastai.optimizer.Optimizer \\| fastai.optimizer.OptimWrapper | Adam | Optimization function for training |\n",
       "| lr | float \\| slice | 0.001 | Default learning rate |\n",
       "| splitter | Callable | trainable_params | Split model into parameter groups. Defaults to one parameter group |\n",
       "| cbs | fastai.callback.core.Callback \\| collections.abc.MutableSequence \\| None | None | `Callback`s to add to `Learner` |\n",
       "| metrics | Union | None | `Metric`s to calculate on validation set |\n",
       "| path | str \\| pathlib.Path \\| None | None | Parent directory to save, load, and export models. Defaults to `dls` `path` |\n",
       "| model_dir | str \\| pathlib.Path | models | Subdirectory to save and load models |\n",
       "| wd | float \\| int \\| None | None | Default weight decay |\n",
       "| wd_bn_bias | bool | False | Apply weight decay to normalization and bias parameters |\n",
       "| train_bn | bool | True | Train frozen normalization layers |\n",
       "| moms | tuple | (0.95, 0.85, 0.95) | Default momentum for schedulers |\n",
       "| default_cbs | bool | True | Include default `Callback`s |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/learner.py#L16){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularLearner\n",
       "\n",
       ">      TabularLearner (dls:fastai.data.core.DataLoaders, model:Callable,\n",
       ">                      loss_func:Optional[Callable]=None, opt_func:fastai.optimi\n",
       ">                      zer.Optimizer|fastai.optimizer.OptimWrapper=<function\n",
       ">                      Adam>, lr:float|slice=0.001, splitter:Callable=<function\n",
       ">                      trainable_params>, cbs:fastai.callback.core.Callback|coll\n",
       ">                      ections.abc.MutableSequence|None=None, metrics:Union[Call\n",
       ">                      able,collections.abc.MutableSequence,NoneType]=None,\n",
       ">                      path:str|pathlib.Path|None=None,\n",
       ">                      model_dir:str|pathlib.Path='models',\n",
       ">                      wd:float|int|None=None, wd_bn_bias:bool=False,\n",
       ">                      train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95),\n",
       ">                      default_cbs:bool=True)\n",
       "\n",
       "*`Learner` for tabular data*\n",
       "\n",
       "|    | **Type** | **Default** | **Details** |\n",
       "| -- | -------- | ----------- | ----------- |\n",
       "| dls | DataLoaders |  | `DataLoaders` containing fastai or PyTorch `DataLoader`s |\n",
       "| model | Callable |  | PyTorch model for training or inference |\n",
       "| loss_func | Optional | None | Loss function. Defaults to `dls` loss |\n",
       "| opt_func | fastai.optimizer.Optimizer \\| fastai.optimizer.OptimWrapper | Adam | Optimization function for training |\n",
       "| lr | float \\| slice | 0.001 | Default learning rate |\n",
       "| splitter | Callable | trainable_params | Split model into parameter groups. Defaults to one parameter group |\n",
       "| cbs | fastai.callback.core.Callback \\| collections.abc.MutableSequence \\| None | None | `Callback`s to add to `Learner` |\n",
       "| metrics | Union | None | `Metric`s to calculate on validation set |\n",
       "| path | str \\| pathlib.Path \\| None | None | Parent directory to save, load, and export models. Defaults to `dls` `path` |\n",
       "| model_dir | str \\| pathlib.Path | models | Subdirectory to save and load models |\n",
       "| wd | float \\| int \\| None | None | Default weight decay |\n",
       "| wd_bn_bias | bool | False | Apply weight decay to normalization and bias parameters |\n",
       "| train_bn | bool | True | Train frozen normalization layers |\n",
       "| moms | tuple | (0.95, 0.85, 0.95) | Default momentum for schedulers |\n",
       "| default_cbs | bool | True | Include default `Callback`s |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(TabularLearner, title_level=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It works exactly like a normal `Learner`; the only difference is that it implements a `predict` method designed to work on a single row of data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@delegates(Learner.__init__)\n",
    "def tabular_learner(\n",
    "        dls:TabularDataLoaders,\n",
    "        layers:list=None, # Size of the layers generated by `LinBnDrop`\n",
    "        emb_szs:list=None, # Tuples of `n_unique, embedding_size` for all categorical features\n",
    "        config:dict=None, # Config params for TabularModel from `tabular_config`\n",
    "        n_out:int=None, # Final output size of the model\n",
    "        y_range:Tuple=None, # Low and high for the final sigmoid function\n",
    "        **kwargs\n",
    "):\n",
    "    \"Get a `Learner` using `dls`, with `metrics`, including a `TabularModel` created using the remaining params.\"\n",
    "    # Work on a copy so that removing `y_range` below never mutates the caller's `config`\n",
    "    config = tabular_config() if config is None else dict(config)\n",
    "    if layers is None: layers = [200,100]\n",
    "    emb_szs = get_emb_sz(dls.train_ds, {} if emb_szs is None else emb_szs)\n",
    "    if n_out is None: n_out = get_c(dls)\n",
    "    assert n_out, \"`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`\"\n",
    "    # Always take `y_range` out of `config` so it is never passed to `TabularModel` twice;\n",
    "    # an explicit `y_range` argument takes precedence over the value stored in `config`\n",
    "    config_y_range = config.pop('y_range', None)\n",
    "    if y_range is None: y_range = config_y_range\n",
    "    model = TabularModel(emb_szs, len(dls.cont_names), n_out, layers, y_range=y_range, **config)\n",
    "    return TabularLearner(dls, model, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your data was built with fastai, you probably won't need to pass anything to `emb_szs` unless you want to change the library's defaults (produced by `get_emb_sz`). The same goes for `n_out`, which should be inferred automatically. `layers` will default to `[200,100]` and is passed to `TabularModel` along with the `config`.\n",
    "\n",
    "Use `tabular_config` to create a `config` and customize the model used. `y_range` is also available directly as an argument of `tabular_learner` because it is so often used.\n",
    "\n",
    "All the other arguments are passed to `Learner`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column roles and preprocessing steps\n",
    "cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']\n",
    "cont_names = ['age', 'fnlwgt', 'education-num']\n",
    "procs = [Categorify, FillMissing, Normalize]\n",
    "\n",
    "path = untar_data(URLs.ADULT_SAMPLE)\n",
    "df = pd.read_csv(path/'adult.csv')\n",
    "\n",
    "# Rows 800-999 are held out as the validation set\n",
    "dls = TabularDataLoaders.from_df(\n",
    "    df, path,\n",
    "    procs=procs, cat_names=cat_names, cont_names=cont_names,\n",
    "    y_names=\"salary\", valid_idx=list(range(800,1000)), bs=64,\n",
    ")\n",
    "learn = tabular_learner(dls)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/learner.py#L18){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularLearner.predict\n",
       "\n",
       ">      TabularLearner.predict (row:pandas.core.series.Series)\n",
       "\n",
       "*Predict on a single sample*\n",
       "\n",
       "|    | **Type** | **Details** |\n",
       "| -- | -------- | ----------- |\n",
       "| row | Series | Features to be predicted |"
      ],
      "text/plain": [
       "---\n",
       "\n",
       "[source](https://github.com/fastai/fastai/blob/main/fastai/tabular/learner.py#L18){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
       "\n",
       "### TabularLearner.predict\n",
       "\n",
       ">      TabularLearner.predict (row:pandas.core.series.Series)\n",
       "\n",
       "*Predict on a single sample*\n",
       "\n",
       "|    | **Type** | **Details** |\n",
       "| -- | -------- | ----------- |\n",
       "| row | Series | Features to be predicted |"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "show_doc(TabularLearner.predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can pass an individual row of data into our `TabularLearner`'s `predict` method. Its output is slightly different from the other `predict` methods, as this one will always return the input as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "row, clas, probs = learn.predict(df.iloc[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>workclass</th>\n",
       "      <th>education</th>\n",
       "      <th>marital-status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>education-num_na</th>\n",
       "      <th>age</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education-num</th>\n",
       "      <th>salary</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Private</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>#na#</td>\n",
       "      <td>Wife</td>\n",
       "      <td>White</td>\n",
       "      <td>False</td>\n",
       "      <td>49.0</td>\n",
       "      <td>101320.001685</td>\n",
       "      <td>12.0</td>\n",
       "      <td>&lt;50k</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "row.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0), tensor([0.5264, 0.4736]))"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clas, probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "#test y_range is passed\n",
    "# `y_range` can be given directly or through `tabular_config`; both must end in a `SigmoidRange`\n",
    "for learner_kwargs in (dict(y_range=(0,32)), dict(config=tabular_config(y_range=(0,32)))):\n",
    "    learn = tabular_learner(dls, **learner_kwargs)\n",
    "    last_layer = learn.model.layers[-1]\n",
    "    assert isinstance(last_layer, SigmoidRange)\n",
    "    test_eq(last_layer.low, 0)\n",
    "    test_eq(last_layer.high, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "@dispatch\n",
    "def show_results(x:Tabular, y:Tabular, samples, outs, ctxs=None, max_n=10, **kwargs):\n",
    "    \"Display the first `max_n` rows of `x` with a `<name>_pred` column from `y` added for each name in `x.y_names`\"\n",
    "    # Copy the row slice before adding columns: assigning into a slice of a pandas frame\n",
    "    # can raise `SettingWithCopyWarning` and may write through to the source data\n",
    "    df = x.all_cols[:max_n].copy()\n",
    "    for n in x.y_names: df[n+'_pred'] = y[n][:max_n].values\n",
    "    display_df(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev import nbdev_export\n",
    "nbdev_export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
